#!/usr/bin/env python
# print("importing libraries")
from tartan_logging import TartanLogging
import numpy as np
import argparse
import dill
import matplotlib
matplotlib.use('pdf')
import matplotlib.pyplot as plt
import cv2
import pylab as pl
import math
from sm import PlotCollection
from matplotlib.backends.backend_pdf import PdfPages
import cv2 


def parseArgs():
    """Parse the command-line options of this plotting script.

    Returns:
        argparse.Namespace with the attributes ``obsfile``, ``obs_it``,
        ``cams``, ``cams_it``, ``cam_names``, ``camids``, ``nbins`` and
        ``babel_args``.

        ``obs_it`` and ``nbins`` are converted to ``int``.  The multi-value
        options (``cams``, ``cams_it``, ``cam_names``, ``camids``,
        ``babel_args``) are lists of strings when given on the command line;
        ``camids`` and ``babel_args`` fall back to the scalar ``0`` when
        omitted, which the rest of the script treats as "not provided".
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-obsfile','--obsfile',required=True,help="The log pickle you want to use.")
    parser.add_argument('-obs_it','--obs_it',required=False,type=int,default=-1,help="The idx (i.e., tartan iteration) you want to use for evaluation. Default =-1, meaning we use the last iteration.")
    parser.add_argument('-cams','--cams',nargs='+',required=True,help="Provide the pickles of camera models to test.")
    parser.add_argument('-cams_it','--cams_it',nargs='+',required=True,help="Provide the iteration from each pickle you want to use for visualization. E.g. if you want to compare two runs and for each the final trained model, enter -1 -1.")
    parser.add_argument('-cam_names','--cam_names',nargs='+',help="Provide names for all cams to print on the plots.")
    parser.add_argument('-camids','--camids',nargs='+',default=0,help="Camera index to read from each log, one per entry of --cams. Default = 0 for every log.")
    # type=int matters: without it a value given on the command line stays a
    # string and the bin arithmetic downstream (nbins + 1) fails.
    parser.add_argument('-nbins','--nbins',required=False,type=int,default=100,help="This is the number of polar-angle bins over which the reprojection error is averaged.")
    parser.add_argument('-babel_args','--babel_args',nargs='+',default=0,help="Optional projection parameters (floats). When given, they are set on the last camera and plotted as one extra curve.")

    return parser.parse_args()

def get_outliers(array, num_stds):
    """Return the column indices of points that count as outliers.

    A point (column) is an outlier when the absolute value of its first or
    second component exceeds ``num_stds`` times the standard deviation of
    that component taken over all points.  Values are compared as-is; the
    mean is not subtracted before thresholding.

    Args:
        array: 2-D array-like laid out as (component, point); only the first
            two rows are inspected.
        num_stds: Threshold expressed in standard deviations.

    Returns:
        Sorted list of int column indices, suitable for
        ``np.delete(array, indices, 1)``.
    """
    values = np.asarray(array, dtype=np.float64)[:2]
    # One standard deviation per component, i.e. along the point axis.  The
    # previous axis=0 produced one value per *point*, so std[0]/std[1] were
    # the spreads of the first two points rather than of the two components.
    # np.float64 replaces the np.float alias that NumPy 1.24 removed.
    std = np.std(values, axis=1, dtype=np.float64)
    exceeds = np.abs(values) > num_stds * std[:, np.newaxis]
    return np.flatnonzero(exceeds.any(axis=0)).tolist()

def rolling_window(a, window):
    """Return every length-``window`` window along the last axis of ``a``.

    The last axis is reflect-padded at its start by ``window - 1`` samples,
    so the output has as many windows as the input has samples on that axis.

    Args:
        a: Array-like with at least one dimension.
        window: Window length (>= 1).

    Returns:
        Strided view of shape ``a.shape + (window,)`` onto a padded copy of
        ``a``.  Windows overlap in memory, so writing to the result changes
        several windows at once; copy it before modifying.
    """
    a = np.asarray(a)
    # Only the last axis is padded.  The pad widths used to start from
    # np.ones, which prepended one reflected slice to every other axis and
    # made 2-D input of shape (R, C) come back as (R + 1, C, window).
    pad_width = [(0, 0)] * (a.ndim - 1) + [(window - 1, 0)]
    padded = np.pad(a, pad_width, mode='reflect')
    shape = padded.shape[:-1] + (padded.shape[-1] - window + 1, window)
    # Re-using the last stride for the new axis makes window k start at
    # sample k of the padded data.
    strides = padded.strides + (padded.strides[-1],)
    return np.lib.stride_tricks.as_strided(padded, shape=shape, strides=strides)

def plotHistograms(evalObs,evalCams,camNames):
    """Evaluate the reprojection errors of each camera on its observations.

    Camera ``k`` of ``evalCams`` is paired with entry ``k`` of ``evalObs``.
    The errors are computed but not drawn: the histogram calls are disabled
    (kept below for reference), so nothing is plotted and ``None`` is
    returned.  ``camNames`` is accepted but not used.
    """
    for idx, camera in enumerate(evalCams):
        observations = evalObs[idx]
        camera.getReprojectionErrors(observations)

    # Disabled histogram rendering:
    # pl.hist(orig_err[0,:],alpha=0.5,bins=np.arange(-5.0,5.0,0.1))
    # pl.hist(orig_err[1,:],alpha=0.5,bins=np.arange(-5.0,5.0,0.1))
    # pl.hist(final_err[0,:],alpha=0.5,bins=np.arange(-5.0,5.0,0.1))
    # pl.hist(final_err[1,:],alpha=0.5,bins=np.arange(-5.0,5.0,0.1))
def _polar_angles_deg(cam, obslist, num_points):
    """Return the polar angle in degrees of the first ``num_points`` corners."""
    corners = np.concatenate(
        [np.asarray(obs.getCornersImageFrame())[:, :2] for obs in obslist])
    polars = np.empty(num_points)
    for j in range(num_points):
        keypoint = np.array([[corners[j, 0]], [corners[j, 1]]])
        ray = np.ravel(cam.keypointToEuclidean(keypoint))
        # Angle between the back-projected ray and the axis of its last
        # component.
        polars[j] = np.rad2deg(np.arctan2(np.linalg.norm(ray[:2]), ray[-1]))
    return polars


def _bin_mean_errors(polars, err_mag, nbins):
    """Average ``err_mag`` over ``nbins`` equal-width polar bins; skip empty bins."""
    edges = np.linspace(np.min(polars), np.max(polars), nbins + 1)
    # digitize() uses half-open bins, so the largest angle lands one past the
    # last bin; clip it back so that point is not silently dropped.
    bin_idxs = np.minimum(np.digitize(polars, edges) - 1, nbins - 1)
    counts = np.bincount(bin_idxs, minlength=nbins)
    sums = np.bincount(bin_idxs, weights=err_mag, minlength=nbins)
    occupied = counts > 0
    centers = 0.5 * (edges[:-1] + edges[1:])
    return centers[occupied], sums[occupied] / counts[occupied]


def plotSphericalError(evalObs,evalCams,camNames,parsed):
    """Plot the mean reprojection error against polar angle, one curve per camera.

    For every camera the per-point reprojection error magnitude is averaged
    inside ``parsed.nbins`` equal-width polar-angle bins and drawn with
    ``pl.plot`` at the bin centres; empty bins are skipped.  Per-component
    mean and standard deviation of the errors are printed for each camera.

    When ``parsed.babel_args`` is not ``0`` one extra curve is drawn: the
    values are set as projection parameters on the last camera (this
    modifies that camera object in place) and it is evaluated again on the
    last observation list.

    NOTE(review): the errors of that extra pass are now computed after the
    parameters are replaced; previously they were computed before, so the
    extra curve reused the errors of the unmodified camera.  This presumes
    getReprojectionErrors() reads the current projection parameters --
    please confirm.

    Args:
        evalObs: One list of observations per camera; each observation
            provides ``getCornersImageFrame()``.
        evalCams: Cameras providing ``getReprojectionErrors()``,
            ``keypointToEuclidean()`` and ``projection()``.
        camNames: Names used in the printed statistics.  May be ``None`` or
            shorter than the number of curves; missing names are replaced
            by a generic label.
        parsed: Parsed options; ``nbins`` and ``babel_args`` are read.

    Raises:
        ValueError: If ``parsed.nbins`` is smaller than 1.
    """
    # nbins may still be a string when it comes straight from the command line.
    nbins = int(parsed.nbins)
    if nbins < 1:
        raise ValueError("nbins must be at least 1, got {0}".format(nbins))

    use_babel = parsed.babel_args != 0
    num_evals = len(evalCams)
    if use_babel:
        num_evals += 1

    matplotlib.rc('xtick', labelsize=20)
    matplotlib.rc('ytick', labelsize=20)

    for i in range(num_evals):
        is_babel_pass = use_babel and i == (num_evals-1)
        if is_babel_pass:
            cam = evalCams[i-1]
            obslist = evalObs[i-1]
            print("Previous params " + str(cam.projection().getParameters()))
            babel_float = np.array([float(babel_arg) for babel_arg in parsed.babel_args])
            cam.projection().setParameters(babel_float)
            print("Settings to : "+ str(cam.projection().getParameters()))
        else:
            cam = evalCams[i]
            print(str(cam.projection().getParameters()))
            obslist = evalObs[i]
        err = np.array(cam.getReprojectionErrors(obslist))
        print("err shape:"+str(np.shape(err)))

        num_points = len(err)
        # Lay the errors out as (component, point).
        err = np.transpose(np.reshape(err,(num_points,2)))
        print("NUM POINTS: "+str(num_points))

        polars = _polar_angles_deg(cam, obslist, num_points)

        # The babel pass has no entry of its own in camNames, and camNames
        # is None when --cam_names is omitted.
        if camNames is not None and i < len(camNames):
            name = camNames[i]
        elif is_babel_pass:
            name = "babel"
        else:
            name = "cam {0}".format(i)
        # Builtin float replaces the np.float alias removed in NumPy 1.24.
        print("Stats for {0}, mean: {1}, std:{2}.".format(name,str(np.mean(err, 1, dtype=float)),str(np.std(err, 1, dtype=float))))

        err_mag = np.linalg.norm(err, axis=0)
        bin_centers, bin_avg_errs = _bin_mean_errors(polars, err_mag, nbins)

        pl.plot(bin_centers,bin_avg_errs)
        pl.xlabel('Polar Angle', fontsize=20)
        pl.ylabel('Reprojection Error', fontsize=20)





def _resolve_cam_ids(camids, num_evals):
    """Expand the --camids option into one int camera index per evaluated log."""
    if isinstance(camids, (list, tuple)):
        ids = [int(cam_id) for cam_id in camids]
    else:
        # Scalar argparse default (0) when the option was not given.
        ids = [int(camids)]
    if len(ids) == 1:
        ids = ids * num_evals
    if len(ids) < num_evals:
        raise ValueError(
            "--camids needs one value, or one per entry of --cams "
            "(got {0} for {1} cams).".format(len(ids), num_evals))
    return ids


def main():
    """Load the observation and camera logs, plot the error curves, save the figure.

    The observation database is taken from iteration ``--obs_it`` of the
    ``--obsfile`` log; camera ``k`` is taken from iteration ``--cams_it[k]``
    of log ``--cams[k]``, using camera index ``--camids[k]``.  The figure is
    written to ``all_reproj_err.png`` in the current directory.

    Raises:
        ValueError: If ``--cams`` and ``--cams_it`` differ in length, or
            ``--camids`` has more than one but fewer values than ``--cams``.
    """
    parsed = parseArgs()

    # Explicit check instead of assert, which is stripped under ``python -O``.
    if len(parsed.cams) != len(parsed.cams_it):
        raise ValueError(
            "--cams and --cams_it must have the same number of entries "
            "({0} vs {1}).".format(len(parsed.cams), len(parsed.cams_it)))
    num_evals = len(parsed.cams)

    # Resolve the camera ids before first use: the option is the scalar 0
    # when omitted, and subscripting that scalar used to raise a TypeError.
    cam_list = _resolve_cam_ids(parsed.camids, num_evals)

    # NOTE: dill, like pickle, can execute arbitrary code while loading;
    # only open log files from a trusted source.
    # Binary mode is required for unpickling on Python 3.
    with open(parsed.obsfile, 'rb') as f:
        tartanObsLog = dill.load(f)

    print("obs iteration:"+str(parsed.obs_it))
    evalObsDb = tartanObsLog[int(parsed.obs_it)].ObservationDatabase_
    evalObs = [evalObsDb.observations[cam_list[i]] for i in range(num_evals)]

    evalCams = []
    for i, camlog in enumerate(parsed.cams):
        with open(camlog, 'rb') as f:
            camlog_loaded = dill.load(f)

        evalCams.append(camlog_loaded[int(parsed.cams_it[i])].camera_[cam_list[i]])


    plotSphericalError(evalObs,evalCams,parsed.cam_names,parsed)
    plt.savefig('all_reproj_err.png')


# Run the plotting pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
